
Conversation

qyqc731
Contributor

@qyqc731 qyqc731 commented Sep 16, 2025

What this PR does / why we need it?

Add a new npu_fused_infer_attention_score op to improve performance in splitfuse cases and resolve long-seq mask problems.

  1. The original op's performance is suboptimal in certain scenarios, necessitating optimization through the new op (npu_fused_infer_attention_score).
  2. For ultra-long sequences (128k), the original operator allocates a large attn_mask, which consumes excessive CPU memory. In contrast, the new op supports a fixed-size compressed mask, effectively resolving this issue.

NOTE1: The current PR retains the original logic and uses a version check of the CANN package to determine whether the new op can be enabled. This ensures no impact on existing users. In future versions, this version check and the original logic will be deprecated, and the new op scheduling will be uniformly adopted.
NOTE2: This PR relies on a future CANN version, which is not yet available.
NOTE3: To enable the new op in chunked prefill, additional_config must be set at least as --additional-config '{"ascend_scheduler_config": {"enabled":true,"enable_chunked_prefill":true}}'.

Does this PR introduce any user-facing change?

No

How was this patch tested?

CI passed


👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing, smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by other future PRs.
  • Write the commit message by filling in the PR description to help reviewers and future developers understand.

If CI fails, you can run linting and testing checks locally according to Contributing and Testing.

@qyqc731 qyqc731 changed the title "chunked prefill splitfuse op integration" to "chunked prefill splitfuse op in" on Sep 16, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This PR mainly integrates the new splitfuse chunked prefill op. The changes touch two files, attention_v1.py and model_runner_v1.py. In attention_v1.py, the attention computation in _forward_v1_style is switched from _npu_paged_attention_splitfuse to npu_fused_infer_attention_score. In model_runner_v1.py, the logic that builds the attention mask for the ChunkedPrefill case is modified.

My review found two critical issues:

  1. In attention_v1.py, the actual_seq_lengths argument passed to the new op is wrong: it uses cumulative token positions rather than sequence lengths, which will lead to incorrect attention results.
  2. In model_runner_v1.py, the attention mask generated for ChunkedPrefill uses a hard-coded size of (2048, 2048), which makes the code fragile and will fail once a sequence exceeds 2048 tokens.

I recommend fixing these two critical issues to ensure correctness and robustness.

num_kv_heads=self.num_kv_heads,
input_layout="TND",
block_size=block_size,
actual_seq_lengths=attn_metadata.query_start_loc[1:],
Contributor

critical

The actual_seq_lengths parameter is passed attn_metadata.query_start_loc[1:], which contains cumulative token positions rather than the length of each sequence. This will most likely produce incorrect attention results. Use attn_metadata.query_lens instead, which holds the correct per-sequence lengths, and make sure it is on the right device.

Suggested change
actual_seq_lengths=attn_metadata.query_start_loc[1:],
actual_seq_lengths=attn_metadata.query_lens.to(query.device),
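
A minimal sketch of the difference between the two fields, assuming the usual prefix-sum layout of query_start_loc; the numbers below are made up for illustration:

import torch

# Hypothetical metadata for a batch of three sequences with 4, 2 and 5 new tokens.
query_start_loc = torch.tensor([0, 4, 6, 11])             # cumulative offsets, length = num_seqs + 1
query_lens = query_start_loc[1:] - query_start_loc[:-1]   # per-sequence lengths: [4, 2, 5]

# query_start_loc[1:] is [4, 6, 11], i.e. cumulative end offsets,
# while query_lens holds the per-sequence lengths the reviewer suggests.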

Contributor

Please consider using attn_metadata.query_lens.

if attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla:
return self.attn_mask_builder.get_splitfuse_attn_mask(
seq_lens, position, self.dtype, self.device)
return torch.triu(torch.ones(2048, 2048), diagonal=1).to(torch.int8)
Contributor

critical

The attention mask uses a hard-coded size of (2048, 2048). This magic number makes the implementation fragile: if any sequence in the batch exceeds 2048 tokens, the mask will be incorrect or out of bounds. The mask size should be derived from the model's maximum sequence length to ensure correctness and avoid magic numbers.

Suggested change
return torch.triu(torch.ones(2048, 2048), diagonal=1).to(torch.int8)
return torch.triu(torch.ones(self.model_config.max_model_len, self.model_config.max_model_len), diagonal=1).to(torch.int8)

Contributor

This is because the new op supports a compressed mask.
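
For context, a rough sketch of why a fixed 2048 x 2048 int8 mask can be acceptable here; the numbers are illustrative, and the "compressed mask" semantics are those described in this thread (the op works from the fixed-size causal tile instead of a full [seq_len, seq_len] mask):

import torch

# Fixed-size compressed causal mask (int8, strict upper-triangular).
assigned_mask_dim = 2048
compressed_mask = torch.triu(
    torch.ones(assigned_mask_dim, assigned_mask_dim), diagonal=1
).to(torch.int8)

# Constant footprint regardless of sequence length: 2048 * 2048 int8 elements, about 4 MiB.
print(compressed_mask.numel())  # 4194304

# A full int8 causal mask for a 128k sequence would instead need
# 131072 * 131072 bytes, roughly 16 GiB, which is the problem described
# in point 2 of the PR description.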

@qyqc731 qyqc731 changed the title "chunked prefill splitfuse op in" to "chunked prefill, access splitfuse op" on Sep 16, 2025
@rjg-lyh
Collaborator

rjg-lyh commented Sep 17, 2025

Could you integrate other scenarios, such as full FlashAttention, using the FIA interface as well, and provide the performance test results?


self.attn_mask_builder = AttentionMaskBuilder(
-    self.model_config.max_model_len, self.dtype)
+    self.model_config.max_model_len, self.dtype, self.device)
Contributor

If we pass self.model_config.max_model_len here, we still have an attention mask of shape [max_model_len, max_model_len]. We may have 2 choices:

  1. Passing self.scheduler_config.max_num_batched_tokens instead of self.model_config.max_model_len (see the rough memory sketch after this list).
  2. Also using the new attention op for the full prefill case.
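
A back-of-the-envelope comparison of the two mask sizes behind choice 1; the concrete numbers are assumptions chosen only to show the orders of magnitude:

# An int8 causal mask of shape [n, n] costs n * n bytes.
max_model_len = 131072            # e.g. a 128k-context model (assumed)
max_num_batched_tokens = 8192     # a typical scheduler budget (assumed)

full_mask_bytes = max_model_len ** 2              # ~16 GiB: impractical to materialize
batched_mask_bytes = max_num_batched_tokens ** 2  # ~64 MiB: bounded by the scheduler, not the model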

num_kv_heads=self.num_kv_heads,
input_layout="TND",
block_size=block_size,
actual_seq_lengths=attn_metadata.query_start_loc[1:],
Contributor

Please consider using attn_metadata.query_lens.

@Angazenn
Contributor

Also note another two points:

  1. Compatibility with different torch_npu/CANN versions.
  2. Fix the bug when padding is introduced in DP (the length of query exceeds attn_metadata.query_lens[-1]); see the sketch after this list.
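
A hypothetical sketch of the padding issue in point 2, assuming query_start_loc follows the usual prefix-sum layout; all names here are illustrative, not the PR's actual fix:

import torch

# With data parallelism, the flattened query may be padded so every rank sees
# the same number of tokens; the padded tail belongs to no real sequence.
query = torch.randn(16, 8, 128)                  # 16 rows, but only 11 are real tokens
query_start_loc = torch.tensor([0, 4, 6, 11])    # cumulative offsets of the real sequences

num_actual_tokens = int(query_start_loc[-1])     # 11
query_unpadded = query[:num_actual_tokens]       # drop the padded rows before calling the fused op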


This pull request has conflicts, please resolve those before we can evaluate the pull request.

num_heads=self.num_heads,
scale_value=self.scale,
out=output)
if self.compressed_mask:
Contributor Author

Add a TODO: this op will be used in more situations in the future.

return self.attn_mask_builder.get_splitfuse_attn_mask(
seq_lens, position, self.dtype, self.device)
if selkf.compressed_mas:
return self.attn_mask_builder.get_splitfuse_attn_mask()
Collaborator

code bug

Contributor

fixed

raise ValueError("Invalid type for tensors")


def verify_torch_npu_version(
Contributor Author

The CANN version needs to be verified, using package.version.
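
A minimal sketch of such a version gate, assuming the installed version is available as a string and using packaging.version for the comparison; the helper name and threshold below are illustrative, not the PR's actual code:

from packaging import version

def is_at_least(installed: str, required: str) -> bool:
    """Return True if the installed torch_npu/CANN version satisfies the requirement."""
    return version.parse(installed) >= version.parse(required)

# Example: only schedule npu_fused_infer_attention_score on new enough packages,
# otherwise keep the original _npu_paged_attention_splitfuse path.
use_fused_infer_attention = is_at_least("2.7.1rc1", "2.7.1rc1")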

Contributor Author

Add a TODO for now.

# so device needs to be passed here.
assigned_mask_dim = 2048
self.chunked_prefill_attn_mask = torch.triu(torch.ones(assigned_mask_dim, assigned_mask_dim), diagonal=1
).to(torch.int8).to(device)
Contributor Author

check type again

Contributor

This op currently only supports an int8 mask.


self.attn_mask_builder = AttentionMaskBuilder(
-    self.model_config.max_model_len, self.dtype)
+    self.model_config.max_model_len, self.dtype, self.device)
Contributor

For choice 1:

Suggested change
self.model_config.max_model_len, self.dtype, self.device)
self.scheduler_config.max_num_batched_tokens, self.dtype, self.device)



qyqc731 and others added 8 commits September 19, 2025 22:56
Signed-off-by: tangtianyi <tangtianyi4@huawei.com>
Signed-off-by: Angazenn <supperccell@163.com>
@MengqingCao MengqingCao added the ready (read for review) and ready-for-test (start test by label for PR) labels Sep 20, 2025
@wangxiyuan wangxiyuan added and removed the ready-for-test (start test by label for PR) label Sep 22, 2025
@qyqc731 qyqc731 changed the title "chunked prefill, access splitfuse op" to "[Perf] Add new npu_fused_infer_attention_score op to improve performance in splitfuse cases and resolve long-seq mask problems" on Sep 22, 2025
Collaborator

@MengqingCao MengqingCao left a comment

Thanks for this work! Just a note: we need to remember to upgrade torch-npu and CANN together in CI later.

@MengqingCao MengqingCao merged commit f1f2c8f into vllm-project:main Sep 22, 2025
59 of 62 checks passed
Mercykid-bash pushed a commit to Mercykid-bash/vllm-ascend that referenced this pull request Sep 22, 2025
…ce in splitfuse cases and resolve long-seq mask problems (vllm-project#2962)

### What this PR does / why we need it?
Add new npu_fused_infer_attention_score op to improve performance in
splitfuse cases and resolve long-seq mask problems.

1. The original op's performance is suboptimal in certain scenarios,
necessitating optimization through the _new op_
(npu_fused_infer_attention_score).
2. For ultra-long sequences (128k), the original operator will allocate
a large attn_mask, which consumes excessive CPU memory. In contrast, the
_new op_ supports a fixed-size compressed mask, effectively resolving
this issue.

NOTE1: The current PR retains the original logic and uses a version
check of the CANN package to determine whether the _new op_ can be
enabled. This ensures no impact on existing users. In future versions,
this version check and the original logic will be deprecated, and the
_new op_ scheduling will be uniformly adopted.
NOTE2: This PR relies on a future CANN version, which is not yet
available.
NOTE3: To enable the new op in chunked prefill, the parameter
additional_config should be set like `--additional-config
'{"ascend_scheduler_config":
{"enabled":true,"enable_chunked_prefill":true}}' \` at least.

### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
CI passed

- vLLM version: v0.10.2
- vLLM main:
vllm-project/vllm@6c5f82e

---------

Signed-off-by: tangtianyi <tangtianyi4@huawei.com>
Signed-off-by: Angazenn <supperccell@163.com>
Co-authored-by: Angazenn <supperccell@163.com>
Signed-off-by: Che Ruan <cr623@ic.ac.uk>
Labels: ready (read for review), ready-for-test (start test by label for PR)
6 participants